archived/bandits_statlog_vw_customEnv/src/io_utils.py (112 lines of code) (raw):
import json
import logging
import os
import shutil
from pathlib import Path
import boto3
import pandas as pd
logger = logging.getLogger(__name__)
def validate_experience(experience):
"""
Validate the collected experience has required keys.
"""
keys = ["observation", "action_prob", "action", "reward"]
for key in keys:
is_valid = key in experience
if not is_valid:
return False
return True
class CSVReader:
"""Reader object that loads experiences from CSV file chunks.
The input files will be read from in an random order."""
def __init__(self, input_files):
self.files = input_files
def get_iterator(self):
for file in self.files:
reader = pd.read_csv(file, chunksize=1000)
for df in reader:
df_no_nans = df.dropna()
for line in df_no_nans.iterrows():
line_dict = line[1].to_dict()
yield line_dict
class JsonLinesReader:
"""Reader object that loads experiences from JSON file chunks.
The input files will be read from in an random order."""
def __init__(self, input_files):
self.files = input_files
self.cur_file = None
self.cur_index = 0
self.max_index = len(input_files) - 1
self.done = False
def get_experience(self):
line = self._next_line()
experience = self._try_parse(line)
while not experience and not self.done:
logger.debug("Skipping empty line in {}".format(self.cur_file))
experience = self._try_parse(self._next_line())
return experience
def _try_parse(self, line):
if line is None or line.strip() == "":
return None
try:
line_json = json.loads(line.strip())
assert "observation" in line_json, "observation not found in record"
assert "action" in line_json, "action not found in record"
assert "reward" in line_json, "reward not found in record"
assert "prob" in line_json, "prob not found in record"
return line_json
except Exception:
logger.exception("Ignoring corrupt json record in {}: {}".format(self.cur_file, line))
return None
def _next_line(self):
if not self.cur_file:
self.cur_file = self._next_file()
if self.done is True:
return None
line = self.cur_file.readline()
tries = 0
while not line and tries < 100:
tries += 1
self.cur_file.close()
self.cur_file = self._next_file()
if self.done is True:
return None
line = self.cur_file.readline()
if not line:
logger.debug("Ignoring empty file {}".format(self.cur_file))
if not line:
raise ValueError("Failed to read next line from files: {}".format(self.files))
return line
def _next_file(self):
if self.cur_index > self.max_index:
self.done = True
return None
path = self.files[self.cur_index]
self.cur_index += 1
return open(path, "r")
def get_vw_model(disk_path=None):
"""
Returns a tuple (str, str) of metadata string and model weights URL on disk.
"""
sagemaker_model_path = Path(disk_path)
meta_files = list(sagemaker_model_path.rglob("vw.metadata"))
if len(meta_files) == 0:
raise ValueError("Algorithm Error: 'vw.metadata' not found in model files.")
metadata_path = meta_files[0]
model_files = list(sagemaker_model_path.rglob("vw.model"))
if len(model_files) == 0:
raise ValueError("Algorithm Error: 'vw.model' not found in model files.")
model_path = model_files[0]
return metadata_path.as_posix(), model_path.as_posix()
def extract_model(tar_gz_folder):
"""
This function extracts the model.tar.gz and then
returns a tuple (str, str) of metadata string and model weights URL on disk.
"""
shutil.unpack_archive(
filename=os.path.join(tar_gz_folder, "model.tar.gz"), extract_dir=tar_gz_folder
)
return get_vw_model(tar_gz_folder)
def parse_s3_uri(uri):
uri = uri.replace("s3://", "")
bucket, *key = uri.split("/")
file_name = key[-1]
key = "/".join(key)
return bucket, key, file_name
def download_manifest_data(manifest_file_path, output_dir):
"""
Download the s3 files contained in a manifest file.
"""
with open(manifest_file_path.as_posix()) as f:
manifest = json.load(f)
s3_prefix = manifest[0]["prefix"]
s3 = boto3.client("s3")
for file in manifest[1:]:
s3_uri = os.path.join(s3_prefix, file)
bucket, key, file_name = parse_s3_uri(s3_uri)
output_file = os.path.join(output_dir.as_posix(), file_name)
s3.download_file(bucket, key, output_file)
print("Downloaded file ", output_file)